package com.lee.algorithm.tree;

import com.lee.algorithm.tree.struct.TreeNode;

import java.util.HashMap;
import java.util.HashSet;

/***
 * @description: 求同一棵二叉树中两个节点的最低公共祖先节点
 *      节点往上，两个节点最先汇聚的节点则是这两个节点的公共祖先
 * @author 博客园 @ 青石路
 * @date: 2021/12/11 9:49
 */
public class LowestCommonAncestor {

    /**
     * 寻找 n1 和 n2 在树 root 中的最低公共祖先节点
     *
     * 树遍历，用哈希表记录下所有节点的父节点
     * 从 n1 开始，从哈希表中找出 n1 的所有祖先节点存入到 HashSet set1 中
     * 从 n2 开始，从哈希表中逐个遍历 n2 的祖先节点，判断n2祖先节点是否存在于 set1 中
     * 一旦存在直接返回，返回的节点肯定是 n1 与 n2 的最低公共祖先节点
     *
     * @param root 树的根节点
     * @param n1
     * @param n2
     * @return
     * @author 博客园 @ 青石路
     */
    public static TreeNode findLca(TreeNode root, TreeNode n1, TreeNode n2) {
        HashMap<TreeNode, TreeNode> fatherMap = new HashMap<>();  // 记录所有节点的父节点
        fatherMap.put(root, root);      // 根节点的父节点就是它自己
        findAllNodeFather(root, fatherMap);
        HashSet<TreeNode> set1 = new HashSet<>();   // 记录 n1 的所有祖先节点

        TreeNode cur = n1;
        while (cur != fatherMap.get(cur)) {         // 将 n1 所有的祖先节点找出来
            set1.add(cur);
            cur = fatherMap.get(cur);
        }

        cur = n2;
        while (cur != fatherMap.get(cur)) {
            if (set1.contains(cur)) {
                return cur;
            }
            cur = fatherMap.get(cur);
        }
        return null;
    }

    /**
     * 找到所有节点的父节点
     * @param root
     * @param fatherMap
     * @author 博客园 @ 青石路
     */
    private static void findAllNodeFather(TreeNode root, HashMap<TreeNode, TreeNode> fatherMap) {
        if (root == null) {
            return;
        }
        fatherMap.put(root.left, root);
        fatherMap.put(root.right, root);
        findAllNodeFather(root.left, fatherMap);
        findAllNodeFather(root.right, fatherMap);
    }

    /**
     * 很抽象
     * 需要结合代码、两个节点最低公共祖先的情况来理解
     *      最低公共祖先的情况
     *      1）n1是n2的祖先节点，那么返回n1；或者n2是n1的祖先节点，那么返回n2
     *      2）n1与n2不互为祖先节点，那么则向上寻找最初汇聚节点，该节点则是最低公共祖先节点
     * @param root
     * @param n1
     * @param n2
     * @return
     * @author 博客园 @ 青石路
     */
    public static TreeNode findLcaPlus(TreeNode root, TreeNode n1, TreeNode n2) {
        if (root == null || root == n1 || root == n2) {
            return root;
        }

        TreeNode left = findLcaPlus(root.left, n1, n2);
        TreeNode right = findLcaPlus(root.right, n1, n2);
        if (left != null && right != null) {
            return root;    // n1 和 n2 的最初汇聚点，一旦找到，该节点会一直往上抛，最终返回值也是它
        }
        return left != null ? left : right;
    }
}
